//hpatchi.c
// patch tool for HPatchLite
//
/*
 The MIT License (MIT)
 Copyright (c) 2020-2022 HouSisong All Rights Reserved.
 */

#include <string.h>
#include <stdio.h>  //fprintf
#include "HDiffPatch/file_for_patch.h"
#include "sys.h"
#include "HDiffPatch/libHDiffPatch/HPatchLite/hpatch_lite.h"
#include "tuz_dec.h"
#include "heap.h"

// Enables the tuz decompresser branch of hpatchi_patch(). The zlib/lzma/lzma2
// branches are compiled only if their _CompressPlugin_* macros are defined elsewhere.
#define _CompressPlugin_tuz
// Running byte sum of all new data produced by the patch (accumulated in
// _do_writeNew, printed at the end of hpatchi()).
unsigned long outsum = 0;

// Applies the patch described by `listener`; defined below.
int hpatchi_patch(hpatchi_listener_t* listener, hpi_compressType compress_type, hpi_pos_t newSize,
    hpi_pos_t uncompressSize, size_t patchCacheSize);

// Result codes returned by hpatchi() and hpatchi_patch().
// Every value is written out explicitly; 14..19 are unused.
typedef enum THPatchiResult {
    HPATCHI_SUCCESS                  = 0,
    HPATCHI_OPTIONS_ERROR            = 1,
    HPATCHI_PATHTYPE_ERROR           = 2,
    HPATCHI_OPENREAD_ERROR           = 3,
    HPATCHI_OPENWRITE_ERROR          = 4,
    HPATCHI_FILEREAD_ERROR           = 5,
    HPATCHI_FILEWRITE_ERROR          = 6,
    HPATCHI_FILEDATA_ERROR           = 7,
    HPATCHI_FILECLOSE_ERROR          = 8,
    HPATCHI_MEM_ERROR                = 9,
    HPATCHI_COMPRESSTYPE_ERROR       = 10,
    HPATCHI_DECOMPRESSER_DICT_ERROR  = 11,
    HPATCHI_DECOMPRESSER_OPEN_ERROR  = 12,
    HPATCHI_DECOMPRESSER_CLOSE_ERROR = 13,
    HPATCHI_PATCH_OPEN_ERROR         = 20,
    HPATCHI_PATCH_ERROR              = 21,
} THPatchiResult;

// Error-handling helpers for functions that declare `int result` and
// `int _isInClear` and define a `clear:` cleanup label.
// Each macro is wrapped in do { } while (0) so it behaves as a single statement
// (safe inside an unbraced if/else).
//
// check_on_error(errorType): records errorType in `result` unless an earlier
// error is already recorded, then jumps to `clear` -- except when already inside
// the cleanup section (_isInClear set), so failing cleanup checks cannot loop.
#define  check_on_error(errorType) do { \
    if (result == HPATCHI_SUCCESS) result = (errorType); \
    if (!_isInClear) { goto clear; } } while (0)
// check(value,errorType,errorInfo): when value is false, logs "<errorInfo> ERROR!"
// and calls check_on_error(errorType). errorInfo must be a string literal because
// it is concatenated with " ERROR!\n".
#define  check(value,errorType,errorInfo) do { \
    if (!(value)) { LOG_ERR(errorInfo " ERROR!\n"); check_on_error(errorType); } } while (0)

// Frees p with vPortFree when non-null and resets it, so repeated calls are harmless.
#define _free_mem(p) do { if (p) { vPortFree(p); (p) = 0; } } while (0)

// Listener handed to hpatch_lite_patch(). `base` must remain the first member:
// the callbacks cast the hpatchi_listener_t* they receive back to TPatchListener*.
typedef struct TPatchListener {
    hpatchi_listener_t      base;
    int                     result;   // set to HPATCHI_SUCCESS in hpatchi(); checked there after a failed patch
    hpatch_TStreamInput* old_file;    // old data source read by _do_readOld
    FILE* new_file;                   // NOTE(review): currently unused -- _do_writeNew only checksums the output
} TPatchListener;

// read_diff adapter: pulls the next chunk of decompressed diff data out of the
// tuz stream passed as diffStream. Any tuz result that compares <= tuz_STREAM_END
// is reported as success.
static hpi_BOOL _tuz_TStream_decompress(hpi_TInputStreamHandle diffStream, hpi_byte* out_part_data, hpi_size_t* data_size) {
    tuz_TStream* const stream = (tuz_TStream*)diffStream;
    return (tuz_TStream_decompress_partial(stream, out_part_data, data_size) <= tuz_STREAM_END);
}
// Reads the tuz dictionary size from the head of the code stream and returns it
// as the number of bytes the decompresser must reserve. Returns 0 on error,
// i.e. when the dict size is 0 or not below tuz_kMaxOfDictSize.
static size_t _tuz_TStream_getReservedMemSize(hpi_TInputStreamHandle codeStream, hpi_TInputStream_read readCode) {
    const tuz_size_t dictSize = tuz_TStream_read_dict_size(codeStream, readCode);
    // (dictSize - 1) wraps around for dictSize == 0, so one unsigned compare
    // rejects both the zero case and the too-large case.
    if (((tuz_size_t)(dictSize - 1)) < tuz_kMaxOfDictSize)
        return dictSize;
    return 0; //error
}
// read_diff callback: serves the raw diff from the global `diff` array (not
// declared in this file -- presumably provided by an included header), copying
// *data_size elements per call. diffStream is ignored. Always returns hpi_TRUE.
// NOTE(review): the read offset is a function-local static that is never reset,
// so the diff can be consumed only once per program run, and reads are not
// bounds-checked against the size of `diff`.
static hpi_BOOL _do_readFile(hpi_TInputStreamHandle diffStream, hpi_byte* out_data, hpi_size_t* data_size) {
    static size_t diffcnt = 0;
    const hpi_size_t n = *data_size;
    hpi_size_t i;
    (void)diffStream;
    // unsigned index of the same type as the length: no sign-compare, no int overflow
    for (i = 0; i < n; i++)
        out_data[i] = diff[diffcnt++];
    return hpi_TRUE;
}
// read_old callback: reads data_size bytes of old data starting at read_from_pos
// into out_data via the listener's old_file stream; returns that stream's result.
static hpi_BOOL _do_readOld(struct hpatchi_listener_t* listener, hpi_pos_t read_from_pos, hpi_byte* out_data, hpi_size_t data_size) {
    hpatch_TStreamInput* const old_stream = ((TPatchListener*)listener)->old_file;
    hpi_byte* const out_end = out_data + data_size;
    return old_stream->read(old_stream, read_from_pos, out_data, out_end);
}
// write_new callback. The new data is not stored anywhere: every byte is added
// to the global checksum `outsum`, which hpatchi() prints at the end.
// Always returns hpi_TRUE.
static hpi_BOOL _do_writeNew(struct hpatchi_listener_t* listener, const hpi_byte* data, hpi_size_t data_size) {
    hpi_size_t i;
    (void)listener; // TPatchListener::new_file is not used yet
    for (i = 0; i < data_size; i++)
        outsum += data[i];
    return hpi_TRUE;
}
// Runs hpatch_lite_patch() with the decompresser selected by compress_type.
//   listener       : diff/old/new callbacks; for tuz, its diff_data/read_diff are
//                    replaced by the tuz decompression stream.
//   compress_type  : compression of the diff payload (as reported by hpatch_lite_open()).
//   newSize        : size of the new data to produce.
//   uncompressSize : currently unused.
//   patchCacheSize : total cache budget. For compressed diffs, 1/4 becomes the
//                    decompress input buffer and the rest (rounded down to even)
//                    the patch buffer; for uncompressed diffs all of it is patch buffer.
// Returns HPATCHI_SUCCESS or a THPatchiResult error code. Memory obtained with
// pvPortMalloc is released before returning.
int hpatchi_patch(hpatchi_listener_t* listener, hpi_compressType compress_type, hpi_pos_t newSize,
    hpi_pos_t uncompressSize, size_t patchCacheSize) {
    int     result = HPATCHI_SUCCESS;
    int     _isInClear = hpatch_FALSE;
    hpi_byte* pmem = 0;
    hpi_byte* temp_cache = 0;
    // patchCacheSize 1/4 for decompress input buf, 3/4 for patch buf
    const size_t    decBufSize = (patchCacheSize >= 4) ? (patchCacheSize >> 2) : 1;
    // decBufSize is at least 1, so guard the subtraction against size_t underflow
    // (patchCacheSize == 0 previously produced a patch buffer size near SIZE_MAX).
    size_t          patchBufSize = (patchCacheSize > decBufSize)
                                 ? ((patchCacheSize - decBufSize) >> 1 << 1) : 0;
#ifdef _CompressPlugin_tuz
    tuz_TStream     tuzStream;
#endif
#ifdef _CompressPlugin_zlib
    zlib_TStream    zlibStream;
#endif
#ifdef _CompressPlugin_lzma
    lzma_TStream    lzmaStream;
#endif
#ifdef _CompressPlugin_lzma2
    lzma2_TStream   lzma2Stream;
#endif
    (void)uncompressSize;

    assert(patchCacheSize == (hpi_size_t)patchCacheSize);
    {//get decompresser
        switch (compress_type) {
        case hpi_compressType_no: { // memory size: patchCacheSize
            printf("hpatchi run with decompresser: \"\"\n");
            patchBufSize = patchCacheSize; // no decompress input buffer is needed
            printf("  requirements memory size: (must) %" PRIu64 " + (cache) %" PRIu64 "\n",
                (hpatch_StreamPos_t)0, (hpatch_StreamPos_t)patchBufSize);
            pmem = (hpi_byte*)pvPortMalloc(patchCacheSize);
            check(pmem, HPATCHI_MEM_ERROR, "alloc cache memory");
            temp_cache = pmem;
        } break;
#ifdef _CompressPlugin_tuz
        case hpi_compressType_tuz: { // requirements memory size: dictSize + patchCacheSize
            size_t  decompressMemSize;
            size_t  reservedMemSize;
            printf("hpatchi run with decompresser: \"tuz\"\n");
            assert(decBufSize == (tuz_size_t)decBufSize);
            reservedMemSize = _tuz_TStream_getReservedMemSize(listener->diff_data, listener->read_diff);
            check(reservedMemSize > 0, HPATCHI_DECOMPRESSER_DICT_ERROR, "tuz_TStream_read_dict_size() dict size");

            // layout of pmem: [dict (reservedMemSize)][dec input buf (decBufSize)][patch buf (patchBufSize)]
            decompressMemSize = reservedMemSize + decBufSize;
            printf("  requirements memory size: (must) %" PRIu64 " + (cache) %" PRIu64 "\n",
                (hpatch_StreamPos_t)reservedMemSize, (hpatch_StreamPos_t)(decBufSize + patchBufSize));
            // size_t must not be printed with %d; use the same 64-bit form as above
            printf("\r\nmalloc memory %" PRIu64 "\r\n", (hpatch_StreamPos_t)(decompressMemSize + patchBufSize));
            pmem = (hpi_byte*)pvPortMalloc(decompressMemSize + patchBufSize);
            check(pmem, HPATCHI_MEM_ERROR, "alloc cache memory");

            check(tuz_OK == tuz_TStream_open(&tuzStream, listener->diff_data, listener->read_diff,
                pmem, (tuz_size_t)reservedMemSize, (tuz_size_t)decBufSize),
                HPATCHI_DECOMPRESSER_OPEN_ERROR, "tuz_TStream_open()");
            temp_cache = pmem + decompressMemSize;

            // from here on the patcher reads decompressed diff data through tuz
            listener->diff_data = &tuzStream;
            listener->read_diff = _tuz_TStream_decompress;
        } break;
#endif
#ifdef _CompressPlugin_zlib
        case hpi_compressType_zlib: { // requirements memory size: 7KB + dictSize + patchCacheSize
            _openDecompresser(&zlibStream, "zlib", listener, _zlib_TStream_init,
                _zlib_TStream_getReservedMemSize, _zlib_TStream_open, _zlib_TStream_decompress);
        } break;
#endif
#ifdef _CompressPlugin_lzma
        case hpi_compressType_lzma: { // requirements memory size: 8KB--32KB + dictSize + patchCacheSize
            _openDecompresser(&lzmaStream, "lzma", listener, _lzma_TStream_init,
                _lzma_TStream_getReservedMemSize, _lzma_TStream_open, _lzma_TStream_decompress);
        } break;
#endif
#ifdef _CompressPlugin_lzma2
        case hpi_compressType_lzma2: { // requirements memory size: 8KB--32KB + dictSize + patchCacheSize
            _openDecompresser(&lzma2Stream, "lzma2", listener, _lzma2_TStream_init,
                _lzma2_TStream_getReservedMemSize, _lzma2_TStream_open, _lzma2_TStream_decompress);
        } break;
#endif
        default: {
            LOG_ERR("unknow compress_type \"%d\" ERROR!\n", (int)compress_type);
            check(hpatch_FALSE, HPATCHI_COMPRESSTYPE_ERROR, "diff info");
        }
        }
    }

    check(hpatch_lite_patch(listener, newSize, temp_cache, (hpi_size_t)patchBufSize),
        HPATCHI_PATCH_ERROR, "hpatch_lite_patch() run");

clear:
    _isInClear = hpatch_TRUE;
#ifdef _CompressPlugin_zlib
    if (hpi_compressType_zlib == compress_type) _closeDecompresser(&zlibStream, "zlib", _zlib_TStream_close);
#endif
#ifdef _CompressPlugin_lzma
    if (hpi_compressType_lzma == compress_type) _closeDecompresser(&lzmaStream, "lzma", _lzma_TStream_close);
#endif
#ifdef _CompressPlugin_lzma2
    if (hpi_compressType_lzma2 == compress_type) _closeDecompresser(&lzma2Stream, "lzma2", _lzma2_TStream_close);
#endif
    _free_mem(pmem);
    return result;
}

int hpatchi(size_t patchCacheSize) {
    int     result = HPATCHI_SUCCESS;
    int     _isInClear = hpatch_FALSE;
    hpatch_TFileStreamOutput    newData;
    hpatch_TFileStreamInput     diffData;
    hpatch_TFileStreamInput     oldData;
    TPatchListener      patchListener;
    hpi_compressType    compress_type;
    hpi_pos_t           uncompressSize;

    patchListener.result = HPATCHI_SUCCESS;
    hpatch_TFileStreamInput_init(&oldData);
    hpatch_TFileStreamInput_init(&diffData);
    hpatch_TFileStreamOutput_init(&newData); 

    check(hpatch_TFileStreamInput_open(&oldData, ""),
        HPATCHI_OPENREAD_ERROR, "open oldFile for read");

    patchListener.base.diff_data = "";
    patchListener.base.read_diff = _do_readFile;
    {//open diff info
        hpi_pos_t newSize;
        if (!hpatch_lite_open(patchListener.base.diff_data, patchListener.base.read_diff,
            &compress_type, &newSize, &uncompressSize))
            check(hpatch_FALSE, HPATCHI_PATCH_OPEN_ERROR, "hpatch_lite_open() run");

        printf("newDataSize : %" PRIu64 "\r\n", (hpatch_StreamPos_t)newSize);
        check(hpatch_TFileStreamOutput_open(&newData, "", newSize),
            HPATCHI_OPENWRITE_ERROR, "open out newFile for write");
    }

    patchListener.old_file = &oldData.base;
    patchListener.base.read_old = _do_readOld;
    //patchListener.new_file = "";
    patchListener.base.write_new = _do_writeNew;
    {
        int patchret = hpatchi_patch(&patchListener.base, compress_type, (hpi_pos_t)newData.base.streamSize,
            uncompressSize, patchCacheSize);
        if (patchret != HPATCHI_SUCCESS) {
            check(patchListener.result == HPATCHI_SUCCESS, patchListener.result, "patchListener");
            check(!oldData.fileError, HPATCHI_FILEREAD_ERROR, "oldFile read");
            check(hpatch_FALSE, patchret, "hpatchi_patch() run");
        }
        printf("  patch ok!\r\n");
    }

clear:
    _isInClear = hpatch_TRUE;
    printf("out sum=%X\r\n", outsum);
    return result;
}

